{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A Simple Mistral 7B Offline RAG with Reranker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/simple-rag-1-reranker.webp\" width=700>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1) Load the embedding model and LLM:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "901c4c15eb2045cfbe8d1419846e3847",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from llama_index.core import Settings, PromptTemplate\n",
    "from llama_index.core.embeddings import resolve_embed_model\n",
    "from llama_index.embeddings.huggingface import HuggingFaceEmbedding\n",
    "from llama_index.llms.huggingface import HuggingFaceLLM\n",
    "\n",
    "import torch\n",
    "\n",
    "Settings.embed_model = HuggingFaceEmbedding(model_name=\"BAAI/bge-small-en-v1.5\")\n",
    "\n",
    "Settings.llm = HuggingFaceLLM(\n",
    "    context_window=2048,\n",
    "    max_new_tokens=256,\n",
    "    generate_kwargs={\"temperature\": 0.25, \"do_sample\": False},\n",
    "    model_name=\"mistralai/Mistral-7B-Instruct-v0.2\",\n",
    "    tokenizer_name=\"mistralai/Mistral-7B-Instruct-v0.2\",\n",
    "    device_map=\"auto\",\n",
    "    \n",
    "    model_kwargs={\n",
    "        \"torch_dtype\": torch.bfloat16,\n",
    "        # Since we are using a small GPU with limited memory\n",
    "        # Set to False if you have a large GPU to speed things up.\n",
    "        \"offload_buffers\": True, \n",
    "    }\n",
    ")\n",
    "\n",
    "Settings.chunk_size = 512"
   ]
  },
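  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional sanity check (a minimal sketch, not part of the pipeline): embed a short string and run a tiny completion to confirm both models are usable. `bge-small-en-v1.5` produces 384-dimensional embeddings; the completion text will vary from run to run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed a short string; bge-small-en-v1.5 should return a 384-dim vector.\n",
    "vec = Settings.embed_model.get_text_embedding(\"hello world\")\n",
    "print(f\"Embedding dimension: {len(vec)}\")\n",
    "\n",
    "# Run a tiny completion to confirm the LLM generates text.\n",
    "print(Settings.llm.complete(\"What is bread made of? Answer briefly.\"))"
   ]
  },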
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2) Load data / read in documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Read documents: {'Basic-Scientific-Food-Preparation-Lab-Manual.txt', 'Basic-Scientific-Food-Preparation-Lab-Manual.pdf'}\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader\n",
    "\n",
    "data_dir =  Path(\"RAGs\") / \"data\" / \"example-1\"\n",
    "\n",
    "documents = SimpleDirectoryReader(data_dir).load_data()\n",
    "\n",
    "# Sanity checks\n",
    "unique_docs = set(d.metadata[\"file_name\"] for d in documents)\n",
    "print(f\"Read documents: {unique_docs}\")"
   ]
  },
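  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the directory holds what appears to be the same manual in two formats (`.txt` and `.pdf`), so its content is likely indexed twice. If that is unintended, `SimpleDirectoryReader` can be restricted by extension. A sketch, assuming the PDF is the copy you want to keep:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical variant: read only the PDF to avoid indexing duplicates.\n",
    "pdf_documents = SimpleDirectoryReader(data_dir, required_exts=[\".pdf\"]).load_data()\n",
    "print(f\"PDF-only read yields {len(pdf_documents)} document objects\")"
   ]
  },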
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3) Create vector database\n",
    "\n",
    "- VectorStoreIndex is an in-memory vector database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c00da099b6544642ad24220fabbcae1d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Parsing nodes:   0%|          | 0/252 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "427a151da7664536ad7e13b6317a6a3c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating embeddings:   0%|          | 0/567 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Database consists of 252 chunks\n"
     ]
    }
   ],
   "source": [
    "index = VectorStoreIndex.from_documents(\n",
    "    documents,\n",
    "    show_progress=True\n",
    ")\n",
    "\n",
    "num_chunks = len(documents)\n",
    "print(f\"Database consists of {num_chunks} chunks\")"
   ]
  },
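  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the index lives in RAM, it disappears when the kernel stops. To reuse it across sessions without re-embedding, it can be persisted to disk and reloaded. A minimal sketch; the `./storage` directory name is arbitrary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import StorageContext, load_index_from_storage\n",
    "\n",
    "# Save the in-memory index (nodes, embeddings, metadata) to disk ...\n",
    "index.storage_context.persist(persist_dir=\"./storage\")\n",
    "\n",
    "# ... and reload it later without re-embedding the documents.\n",
    "restored_index = load_index_from_storage(\n",
    "    StorageContext.from_defaults(persist_dir=\"./storage\")\n",
    ")"
   ]
  },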
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4) Set up query engine with reranker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.postprocessor import LLMRerank\n",
    "\n",
    "query_engine = index.as_query_engine(\n",
    "    similarity_top_k=6,\n",
    "    node_postprocessors=[\n",
    "        LLMRerank(\n",
    "            choice_batch_size=3,\n",
    "            top_n=2,\n",
    "        )\n",
    "    ],\n",
    "    response_mode=\"tree_summarize\",\n",
    ")"
   ]
  },
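  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the reranker has to work with, the retrieval step can be run on its own: a raw retriever returns the `similarity_top_k=6` candidates with their vector-similarity scores, before `LLMRerank` narrows them to 2. A minimal sketch; the query string is just an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the 6 raw candidates the reranker will choose from.\n",
    "retriever = index.as_retriever(similarity_top_k=6)\n",
    "for nws in retriever.retrieve(\"How many tbsp. butter to use for the Cream Puffs?\"):\n",
    "    print(f\"{nws.score:.3f}  {nws.node.get_content()[:80]!r}\")"
   ]
  },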
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5) Set up custom prompt template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import PromptTemplate\n",
    "\n",
    "template = (\n",
    "    \"We have provided context information below. \\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"{context_str}\"\n",
    "    \"\\n---------------------\\n\"\n",
    "    \"Given this information, please answer the question: {query_str}\\n\"\n",
    ")\n",
    "qa_template = PromptTemplate(template)\n",
    "query_engine.update_prompts(\n",
    "    {\"response_synthesizer:text_qa_template\": qa_template}\n",
    ")"
   ]
  },
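  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick check that the override took effect: `get_prompts()` lists the templates currently attached to the query engine, keyed by the same names used in `update_prompts`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confirm the custom template is now attached to the response synthesizer.\n",
    "prompts = query_engine.get_prompts()\n",
    "print(list(prompts.keys()))\n",
    "print(prompts[\"response_synthesizer:text_qa_template\"].get_template())"
   ]
  },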
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6) Query the vector database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n",
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2 tbsp. butter is required for the Cream Puffs.\n"
     ]
    }
   ],
   "source": [
    "response = query_engine.query(\"How many tbsp. butter to use for the Cream Puffs?\")\n",
    "print(response)\n",
    "# Correct answer is 2"
   ]
  },
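  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each response also carries the chunks it was synthesized from; with the reranker above, these are the `top_n=2` surviving nodes. Inspecting them is a useful habit for spotting retrieval failures (a minimal sketch):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show which reranked chunks the answer was synthesized from.\n",
    "for nws in response.source_nodes:\n",
    "    print(f\"score={nws.score}, file={nws.node.metadata.get('file_name')}\")\n",
    "    print(nws.node.get_content()[:120], \"...\")\n",
    "    print(\"---\")"
   ]
  },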
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n",
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1. Caraway seeds and shredded Cheddar cheese. 2. Diced Swiss cheese and paprika.\n"
     ]
    }
   ],
   "source": [
    "response = query_engine.query(\n",
    "    \"What are the toppings for the Braided Bread \"\n",
    "    \"of the Braids, Coffeecake, and Sweet Rolls recipe?\"\n",
    ")\n",
    "print(response)\n",
    "# Correct answer is \n",
    "# 2 tsp. caraway seeds and ½ cup shredded Cheddar cheese.\n",
    "# ½ cup diced Swiss cheese and paprika.\n",
    "# ..."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
